# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABCMeta
import sympy as sm
import numpy as np
from hysop.tools.numpywrappers import npw
from hysop.tools.htypes import (
check_instance,
to_tuple,
InstanceOf,
first_not_None,
to_set,
)
from hysop.tools.decorators import debug
from hysop.tools.sympy_utils import get_derivative_variables, SetupExprI
from hysop.fields.continuous_field import Field
from hysop.fields.discrete_field import DiscreteField, DiscreteScalarFieldView
from hysop.fields.field_requirements import DiscreteFieldRequirements
from hysop.parameters.scalar_parameter import ScalarParameter
from hysop.topology.topology import TopologyView
from hysop.operator.directional.directional import DirectionalOperatorBase
from hysop.backend.device.codegen.base.utils import SortedDict
from hysop.core.memory.memory_request import MemoryRequest
from hysop.topology.cartesian_descriptor import CartesianTopologyDescriptors
from hysop.symbolic import (
Dummy,
time_symbol,
space_symbols,
dspace_symbols,
local_indices_symbols,
global_indices_symbols,
)
from hysop.symbolic.relational import Assignment, AugmentedAssignment
from hysop.symbolic.misc import (
TimeIntegrate,
ApplyStencil,
CodeSection,
MutexLock,
MutexUnlock,
Cast,
)
from hysop.symbolic.tmp import TmpScalar
from hysop.symbolic.array import SymbolicArray, SymbolicBuffer, IndexedBuffer
from hysop.symbolic.field import AppliedSymbolicField, SymbolicDiscreteField
from hysop.symbolic.parameter import SymbolicScalarParameter
from hysop.constants import (
ComputeGranularity,
SpaceDiscretization,
TranspositionState,
DirectionLabels,
SymbolicExpressionKind,
)
from hysop.numerics.odesolvers.runge_kutta import (
TimeIntegrator,
ExplicitRungeKutta,
Euler,
RK2,
RK3,
RK4,
)
from hysop.numerics.interpolation.interpolation import (
MultiScaleInterpolation,
Interpolation,
)
from hysop.numerics.stencil.stencil_generator import (
StencilGenerator,
CenteredStencilGenerator,
MPQ,
)
# Expression types accepted as top-level statements by the symbolic parser.
ValidExpressions = (Assignment,)
[docs]
class ExprDiscretizationInfo:
    # Objects counted with a plain integer per object.
    SimpleCounterTypes = (
        SymbolicArray,
        SymbolicBuffer,
    )
    # Objects counted per component with an integer array (keyed by dfield).
    IndexedCounterTypes = (DiscreteScalarFieldView,)

    def __new__(cls, **kwds):
        return super().__new__(cls, **kwds)

    def __init__(self, **kwds):
        """
        Helper class to store information about discretized symbolic expressions.
        """
        super().__init__(**kwds)
        # object -> read access count (int, or per-component int array)
        self.read_counter = SortedDict()
        # object -> write access count (int, or per-component int array)
        self.write_counter = SortedDict()
        # parameter name -> parameter accessed by the expressions
        self.parameters = SortedDict()
[docs]
def read(self, obj, index=None, count=1):
    """Register `count` read accesses of `obj`.

    Parameters
    ----------
    obj: IndexedCounterTypes or SimpleCounterTypes instance
        Object being read by an expression.
    index: int, optional
        Field component, mandatory for per-component counted objects.
    count: int, optional
        Number of reads to record (defaults to 1).
    """
    check_instance(count, int)
    if isinstance(obj, self.IndexedCounterTypes):
        # per-component counter: one integer per field component
        assert index is not None
        self.read_counter.setdefault(
            obj.dfield, npw.int_zeros(shape=(obj.nb_components,))
        )[index] += count
    elif isinstance(obj, self.SimpleCounterTypes):
        # BUGFIX: 'count' was ignored here (counter was always bumped by 1)
        self.read_counter.setdefault(obj, 0)
        self.read_counter[obj] += count
    else:
        msg = f"Unsupported type {type(obj)}."
        raise TypeError(msg)
[docs]
def write(self, obj, index=None, count=1):
    """Register `count` write accesses of `obj`.

    Parameters
    ----------
    obj: IndexedCounterTypes or SimpleCounterTypes instance
        Object being written by an expression.
    index: int, optional
        Field component, mandatory for per-component counted objects.
    count: int, optional
        Number of writes to record (defaults to 1).
    """
    check_instance(count, int)
    if isinstance(obj, self.IndexedCounterTypes):
        # per-component counter: one integer per field component
        assert index is not None
        self.write_counter.setdefault(
            obj.dfield, npw.int_zeros(shape=(obj.nb_components,))
        )[index] += count
    elif isinstance(obj, self.SimpleCounterTypes):
        # BUGFIX: 'count' was ignored here (counter was always bumped by 1)
        self.write_counter.setdefault(obj, 0)
        self.write_counter[obj] += count
    else:
        msg = f"Unsupported type {type(obj)}."
        raise TypeError(msg)
[docs]
def copy(self):
    """Return a copy of this info; per-component count arrays are duplicated."""
    edi = ExprDiscretizationInfo()
    for obj, counts in self.read_counter.items():
        if isinstance(obj, self.IndexedCounterTypes):
            # setdefault inserts a fresh zero array which is then
            # incremented in place (numpy in-place add).
            counter = edi.read_counter.setdefault(
                obj, npw.int_zeros(shape=(obj.nb_components,))
            )
            counter += counts
        elif isinstance(obj, self.SimpleCounterTypes):
            # plain integers are immutable, direct assignment is a copy
            edi.read_counter[obj] = counts
        else:
            msg = f"Unsupported type {type(obj)}."
            raise TypeError(msg)
    for obj, counts in self.write_counter.items():
        if isinstance(obj, self.IndexedCounterTypes):
            counter = edi.write_counter.setdefault(
                obj, npw.int_zeros(shape=(obj.nb_components,))
            )
            counter += counts
        elif isinstance(obj, self.SimpleCounterTypes):
            edi.write_counter[obj] = counts
        else:
            msg = f"Unsupported type {type(obj)}."
            raise TypeError(msg)
    edi.push_parameters(**self.parameters)
    return edi
[docs]
def update(self, other):
    """Merge the counters and parameters of `other` into this info."""
    check_instance(other, ExprDiscretizationInfo)
    for obj, counts in other.read_counter.items():
        if isinstance(obj, self.IndexedCounterTypes):
            # in-place add onto existing (or freshly inserted zero) array
            counter = self.read_counter.setdefault(
                obj, npw.int_zeros(shape=(obj.nb_components,))
            )
            counter += counts
        elif isinstance(obj, self.SimpleCounterTypes):
            self.read_counter.setdefault(obj, 0)
            self.read_counter[obj] += counts
        else:
            msg = f"Unsupported type {type(obj)}."
            raise TypeError(msg)
    for obj, counts in other.write_counter.items():
        if isinstance(obj, self.IndexedCounterTypes):
            counter = self.write_counter.setdefault(
                obj, npw.int_zeros(shape=(obj.nb_components,))
            )
            counter += counts
        elif isinstance(obj, self.SimpleCounterTypes):
            self.write_counter.setdefault(obj, 0)
            self.write_counter[obj] += counts
        else:
            msg = f"Unsupported type {type(obj)}."
            raise TypeError(msg)
    self.push_parameters(**other.parameters)
[docs]
def push_parameters(self, *param, **kwd_params):
    """Register parameters, given positionally (keyed by their .name)
    or as explicit name=parameter keywords."""
    self.parameters.update(**kwd_params)
    for p in param:
        # BUGFIX: every name was previously mapped to the whole positional
        # tuple ('param') instead of the individual parameter 'p'.
        self.parameters[p.name] = p
def __iadd__(self, rhs):
    """Add `rhs` to every non-zero counter (reads then writes), in place."""
    check_instance(rhs, (np.integer, int))
    increment = int(rhs)
    for counter in (self.read_counter, self.write_counter):
        for obj, counts in counter.items():
            if isinstance(obj, self.IndexedCounterTypes):
                # only components that were actually accessed are bumped
                counts[counts > 0] += increment
            elif isinstance(obj, self.SimpleCounterTypes):
                counter[obj] += increment
            else:
                raise TypeError(f"Unsupported type {type(obj)}.")
    return self
def __imul__(self, rhs):
    """Scale every counter (reads then writes) by `rhs`, in place."""
    check_instance(rhs, (np.integer, int))
    factor = int(rhs)
    for counter in (self.read_counter, self.write_counter):
        for obj, counts in counter.items():
            if isinstance(obj, self.IndexedCounterTypes):
                # in-place assignment keeps the shared array object alive
                counts[...] = factor * counts
            elif isinstance(obj, self.SimpleCounterTypes):
                counter[obj] *= factor
            else:
                raise TypeError(f"Unsupported type {type(obj)}.")
    return self
[docs]
def read_objects(self, types):
    """Return the set of read objects that are instances of `types`."""
    return {obj for obj in self.read_counter.keys() if isinstance(obj, types)}
[docs]
def written_objects(self, types):
    """Return the set of written objects that are instances of `types`.

    BUGFIX: previously filtered `read_counter` (copy-paste error), so it
    returned read objects instead of written ones.
    """
    return set(filter(lambda x: isinstance(x, types), self.write_counter.keys()))
@property
def fields(self):
    """All discrete fields read or written, as a set.

    BUGFIX: these three properties used ``set(a).update(b)`` which returns
    None, so they always evaluated to None; replaced with a set union.
    """
    return self.read_objects(DiscreteScalarFieldView) | self.written_objects(
        DiscreteScalarFieldView
    )

@property
def arrays(self):
    """All symbolic arrays read or written, as a set."""
    return self.read_objects(SymbolicArray) | self.written_objects(SymbolicArray)

@property
def buffers(self):
    """All symbolic buffers read or written, as a set."""
    return self.read_objects(SymbolicBuffer) | self.written_objects(SymbolicBuffer)
[docs]
class SymbolicExpressionInfo:
    """Helper class to store information about parsed symbolic expressions."""

    def __new__(
        cls, name, exprs, dt=None, dt_coeff=None, compute_resolution=None, **kwds
    ):
        # positional data is consumed by __init__; only **kwds goes to object
        return super().__new__(cls, **kwds)

    def __init__(
        self, name, exprs, dt=None, dt_coeff=None, compute_resolution=None, **kwds
    ):
        """
        Gather parsing state for the symbolic expressions of operator `name`.

        Parameters
        ----------
        name: str
            Name of the operator owning the expressions.
        exprs: tuple
            Symbolic expressions to be parsed.
        dt: ScalarParameter, optional
            Simulation timestep, required for TIME_INTEGRATE expressions.
        dt_coeff: float, optional
            Timestep coefficient, required for TIME_INTEGRATE expressions.
        compute_resolution: int or tuple of int, optional
            Forced compute resolution; deduced from fields/arrays if None.
        """
        super().__init__(**kwds)
        self.name = name
        self.exprs = exprs
        # validates expressions and deduces AFFECT vs TIME_INTEGRATE
        self.kind = self.check_expressions(exprs)
        # continuous part
        self.domain = None
        self.input_arrays = SortedDict()
        self.output_arrays = SortedDict()
        self.input_buffers = SortedDict()
        self.output_buffers = SortedDict()
        self.input_fields = SortedDict()
        self.output_fields = SortedDict()
        self.input_params = SortedDict()
        self.output_params = SortedDict()
        self.scalars = SortedDict()
        # names of variables that must be treated as volatile (mutexes)
        self.is_volatile = set()
        self.direction = None
        self.has_direction = None
        if self.kind is SymbolicExpressionKind.TIME_INTEGRATE:
            if not isinstance(dt, ScalarParameter) or not isinstance(dt_coeff, float):
                msg = "Symbolic expressions of kind TIME_INTEGRATE require two extra parameters:"
                msg += "\n *dt: ScalarParameter, got dt={}."
                msg += "\n *dt_coeff: float, got dt_coeff={}."
                msg += "\n Please give simulation timestep as an input parameter of operator {}."
                msg = msg.format(dt, dt_coeff, name)
                raise RuntimeError(msg)
            # the timestep is consumed as an input parameter
            self.input_params[dt.name] = dt
            self.dt = dt
            self.dt_coeff = dt_coeff
        # field requirements part (filled later by requirement extraction)
        self.min_ghosts = None
        self.min_ghosts_per_components = None
        # discrete part (filled by discretize_expressions)
        self.dexprs = None
        self.input_dfields = None
        self.output_dfields = None
        self.inout_dfields = None
        self.discretization_info = None
        self.stencils = None
        self.tmp_vars = None
        if compute_resolution is None:
            self.compute_resolution = None
            self._dim = None
        else:
            compute_resolution = to_tuple(compute_resolution)
            check_instance(compute_resolution, tuple, values=int)
            # a forced resolution also fixes the problem dimension
            self._dim = len(compute_resolution)
            self.compute_resolution = compute_resolution
def _is_discretized(self):
    """Return true if the SymbolicExpressionInfo was discretized."""
    return self.dexprs is not None

def _get_dim(self):
    """Shortcut to domain dimension."""
    assert self._dim is not None
    return self._dim

def _get_fields(self):
    """Return input and output fields."""
    fields = {k: v for (k, v) in self.input_fields.items()}
    fields.update(self.output_fields)
    return fields

def _get_params(self):
    """Return input and output parameters."""
    fields = {k: v for (k, v) in self.input_params.items()}
    fields.update(self.output_params)
    return fields

@property
def max_granularity(self):
    # highest usable compute granularity: all but one dimension
    return self.dim - 1

dim = property(_get_dim)
fields = property(_get_fields)
params = property(_get_params)
is_discretized = property(_is_discretized)
[docs]
def check_expressions(self, exprs):
    """Validate `exprs` and return their kind (delegates to the parser)."""
    return SymbolicExpressionParser.check_expressions(exprs)
[docs]
def check_field(self, field):
    """
    Check if given continuous field is compatible with previously
    parsed fields and arrays.
    """
    check_instance(field, Field)
    if self.domain is None:
        # first field encountered defines the reference domain
        self.domain = field.domain
        if self._dim is None:
            self._dim = field.domain.dim
        elif self._dim != field.domain.dim:
            # dimension was already fixed by an array or compute_resolution
            msg = "Dimension mismatch between field domain dimension and array dimension, "
            msg += f"got {self._dim} and {field.domain.dim}."
            raise ValueError(msg)
    elif field.domain is not self.domain:
        msg = "Domain mismatch for field {}:\n{}\nReference domain was:\n{}."
        msg = msg.format(field.name, field.domain, self.domain)
        raise ValueError(msg)
[docs]
def check_array(self, array):
    """
    Check if given symbolic array is compatible with previously
    parsed fields and arrays.
    """
    check_instance(array, SymbolicArray)
    dim = array.dim
    if self.domain is not None:
        # a field already fixed the reference domain
        if self.domain.dim != dim:
            msg = "Dimension mismatch between field domain dimension and array dimension, "
            msg += f"got {self.domain.dim} and {dim}."
            raise ValueError(msg)
    elif self._dim is not None:
        # dimension fixed by a previous array or compute_resolution
        if self._dim != dim:
            msg = "Dimension mismatch between arrays, got {} and {}."
            msg = msg.format(self._dim, dim)
            raise ValueError(msg)
    else:
        self._dim = dim
[docs]
def discretize_expressions(
    self, input_dfields, output_dfields, force_symbolic_axes
):
    """Bind parsed continuous fields to their discrete counterparts,
    determine transposition axes and discretize all expressions.

    Parameters
    ----------
    input_dfields: dict of Field -> DiscreteScalarFieldView
    output_dfields: dict of Field -> DiscreteScalarFieldView
    force_symbolic_axes: tuple or truthy or None
        Tuple: forced axes. Non-tuple truthy value: axes left symbolic
        (None). None: axes taken from the first dfield, all dfields must
        agree.
    """
    check_instance(input_dfields, dict, keys=Field, values=DiscreteScalarFieldView)
    check_instance(output_dfields, dict, keys=Field, values=DiscreteScalarFieldView)
    # every parsed field must have been given a discretization
    assert len(set(self.input_fields.keys()) - set(input_dfields.keys())) == 0
    assert len(set(self.output_fields.keys()) - set(output_dfields.keys())) == 0
    self.input_dfields = {
        k: v for (k, v) in input_dfields.items() if (k in self.input_fields)
    }
    self.output_dfields = {
        k: v for (k, v) in output_dfields.items() if (k in self.output_fields)
    }
    # fields both read and written through the same underlying discrete data
    self.inout_dfields = {
        k: v
        for (k, v) in self.output_dfields.items()
        if (
            (k in self.input_dfields) and (self.input_dfields[k].dfield is v.dfield)
        )
    }
    self.stencils = SortedDict()
    dfields = tuple(input_dfields.values()) + tuple(output_dfields.values())
    if force_symbolic_axes is not None:
        if isinstance(force_symbolic_axes, tuple):
            axes = force_symbolic_axes
        else:
            # non-tuple truthy value: keep axes symbolic
            axes = None
    elif dfields:
        # deduce axes from the first dfield and enforce consistency
        axes = dfields[0].tstate.axes
        for dfield in dfields:
            if dfield.tstate.axes != axes:
                msg = "Discrete field {} has a topology state axes mismatch {} "
                msg += "with reference axes {}."
                msg = msg.format(dfield.name, dfield.tstate.axes, axes)
                raise RuntimeError(msg)
    else:
        msg = "No discrete fields found in custom symbolic operator."
        raise RuntimeError(msg)
    self.axes = axes
    SymbolicExpressionParser.discretize_expressions(self)
    self.check_dfield_sizes()
[docs]
def setup_expressions(self, work):
    """Finalize expressions with allocated work memory (delegates to the parser)."""
    SymbolicExpressionParser.setup_expressions(self, work)
[docs]
def check_dfield_sizes(self):
    """Check all dfields share a single compute resolution and record it."""
    dfields = set(self.input_dfields.values()).union(self.output_dfields.values())
    if len(dfields) > 0:
        dfield0 = next(iter(dfields))
        # a forced resolution takes precedence over the first dfield's one
        compute_resolution = first_not_None(
            self.compute_resolution, dfield0.compute_resolution
        )
        for dfield in dfields:
            # resolutions are arrays: any() flags a per-axis mismatch
            if (dfield.compute_resolution != compute_resolution).any():
                msg = "Mismatching compute resolution {}::{} vs {}::{}."
                msg = msg.format(
                    dfield.name,
                    dfield.compute_resolution,
                    dfield0.name,
                    dfield0.compute_resolution,
                )
                raise ValueError(msg)
        compute_resolution = tuple(compute_resolution)
        self.compute_resolution = compute_resolution
[docs]
def check_arrays(self):
    """Check that arrays are bound to memory and match the compute resolution."""
    compute_resolution = self.compute_resolution
    arrays = set(self.input_arrays.values()).union(self.output_arrays.values())
    for a in arrays:
        if not a.is_bound:
            msg = "FATAL ERROR: {}::{} has not been bound to any memory "
            msg += "prior to setup."
            msg = msg.format(type(a).__name__, a.name)
            raise RuntimeError(msg)
        dim = a.dim
        shape = a.shape
        if len(shape) != dim:
            msg = "FATAL ERROR: {}::{} array shape does not match array dimension."
            msg = msg.format(type(a).__name__, a.name)
            raise RuntimeError(msg)
        if compute_resolution is None:
            # first bound array fixes the compute resolution
            compute_resolution = shape
        elif not npw.array_equal(shape, compute_resolution):
            msg = "FATAL ERROR: {}::{} array shape {} does not comply with determined "
            msg += "compute resolution {}."
            msg = msg.format(type(a).__name__, a.name, shape, compute_resolution)
            raise RuntimeError(msg)
    if compute_resolution is None:
        # neither fields nor arrays determined a resolution
        msg = "FATAL ERROR: Something went wrong while determining compute_resolution."
        raise RuntimeError(msg)
    self.compute_resolution = tuple(compute_resolution)
[docs]
def check_buffers(self):
    """Verify that every input/output buffer has been bound to memory."""
    all_buffers = set(self.input_buffers.values()) | set(self.output_buffers.values())
    for buf in all_buffers:
        if buf.is_bound:
            continue
        msg = (
            "FATAL ERROR: {}::{} has not been bound to any memory "
            "prior to setup."
        ).format(type(buf).__name__, buf.name)
        raise RuntimeError(msg)
[docs]
def determine_direction(self, *variables):
    """Infer the unique spatial direction implied by derivative variables,
    raising if two incompatible directions are encountered."""
    assert self.domain is not None
    direction = self.direction
    for var in variables:
        if var not in space_symbols:
            continue
        if direction is None:
            direction = space_symbols.index(var)
            continue
        xd = space_symbols[direction]
        if xd != var:
            msg = (
                "Expression already contained a derivative with respect to {} (direction {})."
                "\nFound a new derivative direction which is not compatible with the current "
                "one."
                "\nCannot differentiate with respect to {}."
            ).format(xd, direction, var)
            raise ValueError(msg)
    self.direction = direction
def __str__(self):
    """Human-readable summary of the continuous and (if available)
    discretized parsing state."""
    msg = """
::SymbolicExpressionInfo::
expression kind: {}
continuous expressions:{}
input_fields: {}
output_fields: {}
input_arrays: {}
output_arrays: {}
input_buffers: {}
output_buffers: {}
input_params: {}
output_params: {}
discretizations:{}
""".format(
        self.kind,
        "\n" + "\n".join(f" {i}/ {e}" for i, e in enumerate(self.exprs)),
        (
            ", ".join(f"{f.name}" for f in self.input_fields.keys())
            if self.input_fields
            else "none"
        ),
        (
            ", ".join(f"{f.name}" for f in self.output_fields.keys())
            if self.output_fields
            else "none"
        ),
        (
            ", ".join(f"{f}" for f in self.input_arrays.keys())
            if self.input_arrays
            else "none"
        ),
        (
            ", ".join(f"{f}" for f in self.output_arrays.keys())
            if self.output_arrays
            else "none"
        ),
        (
            ", ".join(f"{f}" for f in self.input_buffers.keys())
            if self.input_buffers
            else "none"
        ),
        (
            ", ".join(f"{f}" for f in self.output_buffers.keys())
            if self.output_buffers
            else "none"
        ),
        (
            ", ".join(f"{p}" for p in self.input_params.keys())
            if self.input_params
            else "none"
        ),
        (
            ", ".join(f"{p}" for p in self.output_params.keys())
            if self.output_params
            else "none"
        ),
        # per-field discretization (topologies are shortened for readability)
        "\n"
        + "\n".join(
            " {}: {}".format(
                f.name,
                (d.short_description() if isinstance(d, TopologyView) else d),
            )
            for (f, d) in self.fields.items()
        ),
    )
    if self.min_ghosts:
        msg += """
min_ghosts_per_components:{}
min_ghosts:{}
""".format(
            "\n"
            + "\n".join(
                " {}/ [{}]".format(f.name, ", ".join(str(x) for x in gpc))
                for f, gpc in self.min_ghosts_per_components.items()
            ),
            "\n"
            + "\n".join(
                " {}/ [{}]".format(f.name, ", ".join(str(x) for x in g))
                for f, g in self.min_ghosts.items()
            ),
        )
    if self.is_discretized:
        msg += """
discretized expressions:{}
read_counter: {}
write_counter: {}
""".format(
            "\n" + "\n".join(f" {i}/ {e}" for i, e in enumerate(self.dexprs)),
            (
                ", ".join(
                    f"{f.name}{self.discretization_info.read_counter[f.dfield]}"
                    for f in self.input_dfields.values()
                )
                if self.input_dfields
                else "none"
            ),
            (
                ", ".join(
                    f"{f.name}{self.discretization_info.write_counter[f.dfield]}"
                    for f in self.output_dfields.values()
                )
                if self.output_dfields
                else "none"
            ),
        )
    return msg
def _get_fields(self):
"""Return all fields as a set."""
return set(self._input_fields.values()).update(self._output_fields.values())
@property
def extracted_exprs(self):
    # continuous expressions, flattened by the parser's extractor
    return SymbolicExpressionParser.extract_expressions(self.exprs)

@property
def extracted_dexprs(self):
    # discretized expressions, flattened by the parser's extractor
    return SymbolicExpressionParser.extract_expressions(self.dexprs)
[docs]
class SymbolicExpressionParser:
    """Helper class to parse symbolic expressions."""
[docs]
@classmethod
def check_expressions(cls, exprs):
    """Validate assignment expressions and deduce their global kind.

    Returns SymbolicExpressionKind.AFFECT or TIME_INTEGRATE; all
    expressions must agree on one kind. Also checks that each field
    component and each array is written at most once.
    """
    kind = None
    fields = SortedDict()  # (field, component) -> writing expression
    arrays = SortedDict()  # array -> writing expression
    exprs = tuple(
        filter(
            lambda e: isinstance(e, ValidExpressions),
            cls.extract_expressions(exprs),
        )
    )
    for expr in exprs:
        check_instance(expr, ValidExpressions)
        lhs = expr.args[0]
        field = None
        array = None
        if isinstance(lhs, TmpScalar):
            # temporary scalars do not constrain the expression kind
            continue
        elif isinstance(lhs, IndexedBuffer):
            pass
        elif isinstance(lhs, (AppliedSymbolicField, SymbolicArray)):
            if kind is None:
                kind = SymbolicExpressionKind.AFFECT
            elif kind is not SymbolicExpressionKind.AFFECT:
                msg = "Symbolic expression kind was set to be {} but found "
                msg += " an expression that is of kind {}.\n expr: {}"
                msg = msg.format(kind, SymbolicExpressionKind.AFFECT, expr)
                raise ValueError(msg)
            if isinstance(lhs, AppliedSymbolicField):
                field = lhs
            else:
                array = lhs
        elif isinstance(lhs, sm.Derivative):
            _vars = get_derivative_variables(lhs)
            _t = time_symbol
            unique_vars = set(_vars)
            if isinstance(lhs.args[0], AppliedSymbolicField):
                field = lhs.args[0]
            elif isinstance(lhs.args[0], SymbolicArray):
                msg = "Assignment LHS cannot be a derivative of a SymbolicArray "
                msg += "because ghosts are not handled as symbolic Fields."
                raise TypeError(msg)
            else:
                msg = "Assignment LHS cannot be a derivative of a {}."
                msg = msg.format(type(lhs))
                raise TypeError(msg)
            # only a pure first-order time derivative is allowed on the LHS
            if (_t not in unique_vars) or len(unique_vars) != 1:
                msg = "Assignment LHS can only be a derivative of time {}, got {}."
                msg = msg.format(_t, ", ".join(str(x) for x in unique_vars))
                raise TypeError(msg)
            if kind is None:
                kind = SymbolicExpressionKind.TIME_INTEGRATE
            elif kind is not SymbolicExpressionKind.TIME_INTEGRATE:
                msg = "Symbolic expression kind was set to be {} but found "
                msg += " an expression that is of kind {}.\n expr: {}"
                msg = msg.format(kind, SymbolicExpressionKind.TIME_INTEGRATE, expr)
                raise ValueError(msg)
        else:
            msg = f"Assignment LHS cannot be of type {type(lhs)}."
            raise TypeError(msg)
        if field is not None:
            assert isinstance(field, AppliedSymbolicField)
            # key on (continuous field, component index)
            index = field.index
            field = field.field
            key = (field, index)
            if key in fields:
                msg = "Field {} was already written by expression\n"
                msg += "{}\ncannot write it again in expression\n"
                msg += "{}\nFATAL ERROR: Invalid expressions."
                msg = msg.format(field.name, fields[key], expr)
                raise ValueError(msg)
            fields[key] = expr
        if array is not None:
            assert isinstance(array, SymbolicArray)
            key = array
            if key in arrays:
                msg = "Array {} was already written by expression\n"
                msg += "{}\ncannot write it again in expression\n"
                msg += "{}\nFATAL ERROR: Invalid expressions."
                msg = msg.format(array.name, arrays[key], expr)
                raise ValueError(msg)
            arrays[key] = expr
    if kind is None:
        # no constraining expression found: default to AFFECT
        kind = SymbolicExpressionKind.AFFECT
    return kind
[docs]
@classmethod
def parse(cls, name, variables, *exprs, **kwds):
    """Parse `exprs` into a SymbolicExpressionInfo for operator `name`.

    Parameters
    ----------
    name: str
        Operator name, forwarded to SymbolicExpressionInfo.
    variables: mapping of Field -> topology descriptor
        Discretization info for every field read or written.
    exprs: symbolic expressions to parse.
    kwds: forwarded to SymbolicExpressionInfo (may contain
        'preferred_direction', defaulting to 0).
    """
    preferred_direction = first_not_None(kwds.pop("preferred_direction", None), 0)
    info = SymbolicExpressionInfo(name, exprs, **kwds)
    for expr in cls.extract_expressions(exprs):
        cls.parse_one(variables, info, expr)
    if info._dim is None:
        msg = "\n\nFATAL ERROR: Neither SymbolicFields nor SymbolicArrays were present in parsed "
        msg += "symbolic expressions and compute_resolution has not been specified."
        msg += "\nAt least one is needed to deduce the shape of the compute kernel."
        msg += "\n"
        msg += "\nExpressions were:"
        for i, e in enumerate(exprs):
            # BUGFIX: format spec was '{i:2>}' (fill char '2', no width -- a
            # no-op); '{i:>2}' right-aligns the index in a 2-char column.
            msg += f"\n {i:>2}/ {e}"
        msg += "\n"
        raise RuntimeError(msg)
    if info.direction is None:
        info.direction = preferred_direction
    return info
[docs]
@classmethod
def parse_one(cls, variables, info, expr):
    """Parse a single top-level expression; assignments get dedicated
    handling, everything else goes through the generic sub-parser."""
    if isinstance(expr, Assignment):
        cls.parse_assignment(variables, info, *expr.args)
    else:
        try:
            cls.parse_subexpr(variables, info, expr)
        except:  # noqa: E722 -- diagnostic print only, exception is re-raised
            msg = "Failed to parse symbolic expression type {}."
            print()
            print(msg.format(type(expr)))
            print()
            raise
[docs]
@classmethod
def parse_assignment(cls, variables, info, lhs, rhs):
    """Register the write on `lhs` and the reads implied by `rhs`."""
    if isinstance(
        lhs, (AppliedSymbolicField, SymbolicArray, IndexedBuffer, TmpScalar)
    ):
        cls.write(variables, info, lhs)
        cls.parse_subexpr(variables, info, rhs)
        if isinstance(lhs, IndexedBuffer):
            # NOTE(review): passes lhs.index as the 'variables' argument and
            # re-parses rhs; possibly intended
            # parse_subexpr(variables, info, lhs.index) -- confirm upstream.
            cls.parse_subexpr(lhs.index, info, rhs)
    elif isinstance(lhs, sm.Derivative):
        # time-integrated field: it is both read and written
        f = lhs.args[0]
        cls.read(variables, info, f)
        cls.write(variables, info, f)
        cls.parse_subexpr(variables, info, rhs)
    else:
        msg = "Unknown expression type {}.\n __mro__ = {}\nExpression is: {}\n"
        msg = msg.format(type(lhs), type(lhs).__mro__, lhs)
        raise NotImplementedError(msg)
[docs]
@classmethod
def parse_subexpr(cls, variables, info, expr):
    """Recursively parse a sub-expression, registering reads, mutex
    accesses, derivative directions and parameters into `info`."""
    if isinstance(expr, npw.ndarray):
        # only 0-dimensional arrays (scalar values) are allowed
        assert expr.ndim == 0, expr
        expr = expr.tolist()
    if isinstance(expr, (str, int, float, complex, npw.number)):
        return
    elif isinstance(
        expr, (AppliedSymbolicField, SymbolicScalarParameter, SymbolicArray)
    ):
        cls.read(variables, info, expr)
    elif isinstance(expr, Cast):
        cls.parse_subexpr(variables, info, expr.expr)
    elif isinstance(expr, MutexLock):
        # locking both reads and writes the mutex storage
        var = expr.mutexes
        cls.read(variables, info, var)
        cls.write(variables, info, var)
        info.is_volatile.add(var.name)
    elif isinstance(expr, MutexUnlock):
        var = expr.mutexes
        cls.write(variables, info, var)
        info.is_volatile.add(var.name)
    elif isinstance(expr, sm.Derivative):
        # spatial derivatives constrain the operator direction
        dvars = get_derivative_variables(expr)
        info.determine_direction(*dvars)
        cls.parse_subexpr(variables, info, expr.args[0])
    elif isinstance(expr, (sm.Expr, sm.Rel)):
        # generic sympy node: recurse on children
        for e in expr.args:
            cls.parse_subexpr(variables, info, e)
    else:
        msg = "Unknown expression type {}.\n __mro__ = {}\nExpression is: {}\n"
        msg = msg.format(type(expr), type(expr).__mro__, expr)
        raise NotImplementedError(msg)
[docs]
@classmethod
def write(cls, variables, info, var):
    """Register `var` as written, placing it in the matching output
    container of `info` (scalars, arrays, buffers, fields or params)."""
    if isinstance(var, TmpScalar):
        info.scalars[var.varname] = var
    elif isinstance(var, IndexedBuffer):
        # register the underlying buffer object
        cls.write(variables, info, var.indexed_object)
    elif isinstance(var, SymbolicArray):
        array = var
        info.check_array(array)
        if array.name not in info.output_arrays:
            info.output_arrays[array.name] = array
        else:
            # same name must map to the very same array object
            assert info.output_arrays[array.name] is array
    elif isinstance(var, SymbolicBuffer):
        buf = var
        if buf.name not in info.output_buffers:
            info.output_buffers[buf.name] = buf
        else:
            assert info.output_buffers[buf.name] is buf
    elif isinstance(var, AppliedSymbolicField):
        field = var.field
        info.check_field(field)
        if field not in variables:
            msg = (
                "Field {} is written but no discretization was given in variables."
            )
            msg = msg.format(field.name)
            raise ValueError(msg)
        if field not in info.output_fields:
            info.output_fields[field] = variables[field]
    elif isinstance(var, SymbolicScalarParameter):
        param = var.parameter
        pname = param.name
        if param.const:
            msg = "FATAL ERROR: Cannot assign value to constant parameter {}."
            msg = msg.format(pname)
            raise ValueError(msg)
        elif (pname in info.output_params) and (
            info.output_params[pname] is not param
        ):
            msg = "Incompatible parameter names {}."
            msg = msg.format(pname)
            raise ValueError(msg)
        info.output_params[pname] = param
    else:
        msg = "Unknown written variable type {}.\n __mro__ = {}\n"
        msg = msg.format(type(var), type(var).__mro__)
        raise NotImplementedError(msg)
[docs]
@classmethod
def read(cls, variables, info, var, offset=None):
    """Register `var` as read, placing it in the matching input container
    of `info` (fields, arrays, buffers or params).

    `offset` is currently unused; kept for interface compatibility.
    """
    if isinstance(var, IndexedBuffer):
        cls.read(variables, info, var.indexed_object)
    elif isinstance(var, AppliedSymbolicField):
        field = var.field
        info.check_field(field)
        if field not in variables:
            msg = "Field {} is read but no discretization was given in variables."
            msg = msg.format(field.name)
            raise ValueError(msg)
        if field not in info.input_fields:
            info.input_fields[field] = variables[field]
    elif isinstance(var, SymbolicArray):
        array = var
        info.check_array(array)
        # BUGFIX: input containers are keyed by name (cf. write()); the
        # previous tests checked the object itself against name keys, which
        # never matched and bypassed the identity sanity asserts.
        if array.name not in info.input_arrays:
            info.input_arrays[array.name] = array
        else:
            assert info.input_arrays[array.name] is array
    elif isinstance(var, SymbolicBuffer):
        buf = var
        if buf.name not in info.input_buffers:
            info.input_buffers[buf.name] = buf
        else:
            assert info.input_buffers[buf.name] is buf
    elif isinstance(var, SymbolicScalarParameter):
        param = var.parameter
        if param.name in info.input_params:
            assert (
                info.input_params[param.name] is param
            ), "Incompatible parameter names."
        else:
            info.input_params[param.name] = param
    else:
        msg = "Unknown read variable type {}.\n __mro__ = {}\n"
        msg = msg.format(type(var), type(var).__mro__)
        raise NotImplementedError(msg)
@classmethod
def _extract_obj_requirements(cls, info, expr):
    """Recursively collect the requirements (mainly ghost layers induced by
    spatial derivatives) implied by `expr`.

    Returns a dict mapping objects -- SymbolicArray, or (field, index)
    pairs -- to their requirement objects.
    """
    if isinstance(expr, npw.ndarray):
        # only 0-dimensional (scalar) arrays are allowed
        assert expr.ndim == 0
        expr = expr.tolist()
    if isinstance(
        expr, (int, sm.Integer, float, complex, sm.Rational, sm.Float, npw.number)
    ):
        return {}
    elif isinstance(expr, Cast):
        return cls._extract_obj_requirements(info, expr.expr)
    elif isinstance(expr, SymbolicArray):
        return {expr: expr.new_requirements()}
    elif isinstance(expr, AppliedSymbolicField):
        field = expr.field
        index = expr.index
        return {
            (field, index): DiscreteFieldRequirements(
                operator=None, variables=None, field=field, _register=False
            )
        }
    elif isinstance(expr, str):
        return {}
    elif isinstance(expr, sm.Derivative):
        dexpr = expr.args[0]
        dvars = get_derivative_variables(expr)
        unique_dvars = set(dvars)
        # only spatial variables (and time) may be differentiated
        invalid_dvars = unique_dvars - set(space_symbols) - {time_symbol}
        if invalid_dvars:
            msg = "Cannot differentiate with respect to variable(s) {}."
            msg = msg.format(", ".join(str(x) for x in invalid_dvars))
            msg += "\nOnly allowed variables are: {}".format(
                " ,".join(str(x) for x in space_symbols)
            )
            raise ValueError(msg)
        direction = info.direction
        if direction is not None:
            # a direction was already fixed: no other direction may appear
            xd = space_symbols[direction]
            if unique_dvars - {xd}:
                msg = (
                    "Expression already contained a derivative with respect to {} "
                )
                msg += "(direction {}, {}-axis)."
                msg += "\nFound a new derivative direction which is not compatible "
                msg += "with the current one."
                msg += "\nCannot differentiate with respect to {}."
                msg = msg.format(
                    xd,
                    direction,
                    DirectionLabels[direction],
                    ", ".join(str(x) for x in (unique_dvars - {xd})),
                )
                raise RuntimeError(msg)
        else:
            # first spatial derivative encountered fixes the direction
            if len(unique_dvars) > 1:
                msg = "Cannot differentiate on different variables at a time: {}"
                msg = msg.format(", ".join(str(x) for x in unique_dvars))
                raise ValueError(msg)
            xd = dvars[0]
            assert xd in space_symbols, xd
            direction = space_symbols.index(xd)
            info.direction = direction
        derivative = len(dvars)
        # NOTE(review): info.space_discretization is not set in __init__;
        # presumably assigned by the owning operator before this runs.
        order = info.space_discretization
        dxd = dspace_symbols[direction]
        assert order > 0, order
        assert order % 2 == 0, order
        # build a 1D centered FD stencil to size the required ghost layer
        csg = CenteredStencilGenerator()
        csg.configure(dim=1, dtype=MPQ, derivative=derivative)
        stencil = csg.generate_exact_stencil(order=order)
        min_ghosts = max(stencil.L, stencil.R)
        obj_reqs = cls._extract_obj_requirements(info, dexpr)
        for obj, req in obj_reqs.items():
            # ghost axes are stored in memory order: the direction index
            # is counted from the last axis
            req.min_ghosts[-1 - direction] += min_ghosts
        return obj_reqs
    elif isinstance(expr, Assignment):
        lhs, rhs = expr.args
        if isinstance(lhs, sm.Derivative):
            assert len(lhs.args) == 2
            try:
                assert lhs.args[1] == time_symbol
            except:  # noqa: E722 -- sympy >= 1.2 stores (symbol, order) pairs
                assert lhs.args[1][0] == time_symbol
                assert lhs.args[1][1] == 1
            lhs = lhs.args[0]
        # only the rhs can generate ghost requirements
        freqs = cls._extract_obj_requirements(info, rhs)
        return freqs
    elif isinstance(expr, (sm.Expr, sm.Rel)):
        # generic sympy node: merge children requirements per object
        obj_requirements = SortedDict()
        for e in expr.args:
            obj_reqs = cls._extract_obj_requirements(info, e)
            for obj, reqs in obj_reqs.items():
                if obj in obj_requirements:
                    obj_requirements[obj].update_requirements(reqs)
                else:
                    obj_requirements[obj] = reqs
        return obj_requirements
    else:
        msg = "Unknown expression type {}.\n __mro__ = {}\n"
        msg = msg.format(type(expr), type(expr).__mro__)
        raise NotImplementedError(msg)
[docs]
@classmethod
def discretize_expressions(cls, info):
    """Discretize every expression of `info`, accumulating per-object
    read/write counters into a single ExprDiscretizationInfo."""
    check_instance(info, SymbolicExpressionInfo)
    edi = ExprDiscretizationInfo()
    discretized = []
    for expression in info.exprs:
        dexpr, di = cls.discretize_one(info, expression)
        discretized.append(dexpr)
        edi.update(di)
    info.dexprs = tuple(discretized)
    info.discretization_info = edi
[docs]
@classmethod
def discretize_one(cls, info, expr):
    """Discretize a single top-level expression (delegates to discretize_subexpr)."""
    return cls.discretize_subexpr(info, expr)
[docs]
@classmethod
def discretize_assignment(cls, info, expr):
    """Discretize one Assignment: discretize the rhs, rebuild the lhs on
    discrete data, and wrap time derivatives into a TimeIntegrate node.

    Returns (discretized expression, ExprDiscretizationInfo).
    """
    msg = "Unsupported, use Assignment instead of {}."
    msg = msg.format(type(expr).__name__)
    assert not isinstance(expr, AugmentedAssignment), msg
    lhs, rhs = expr.args
    rhs, di = cls.discretize_subexpr(info, rhs)
    if isinstance(
        lhs,
        (
            AppliedSymbolicField,
            SymbolicArray,
            IndexedBuffer,
            TmpScalar,
        ),
    ):
        func = expr.func
    elif isinstance(lhs, sm.Derivative):
        # only d(field)/dt is supported as an lhs derivative
        assert isinstance(lhs.args[0], AppliedSymbolicField)
        assert len(lhs.args) == 2
        try:
            assert lhs.args[1] == time_symbol
        except:  # noqa: E722 -- sympy >= 1.2 stores (symbol, order) pairs
            assert lhs.args[1][0] == time_symbol
            assert lhs.args[1][1] == 1
        lhs = lhs.args[0]
        assert expr.func is Assignment
        # NOTE(review): relies on info.time_integrator being set elsewhere
        # before discretization (not initialized in __init__) -- confirm.
        func = lambda *args: TimeIntegrate(info.time_integrator, *args)
        dfield = info.input_dfields[lhs.field]
        # time integration also reads the integrated field
        cls.read_discrete(info, lhs, dfield, di)
    else:
        msg = "Invalid symbolic assignment lhs type {}."
        msg = msg.format(type(lhs))
        raise NotImplementedError(msg)
    if isinstance(lhs, AppliedSymbolicField):
        # replace the continuous applied field by its discrete symbol
        field, index, indexed_field = lhs.field, lhs.index, lhs.indexed_field
        dfield = info.output_dfields[field]
        lhs = dfield.s[index]
        check_instance(lhs, SymbolicDiscreteField)
        cls.write_discrete(info, lhs, dfield, di)
    elif isinstance(lhs, IndexedBuffer):
        di.write(lhs.indexed_object)
        index, edi = cls.discretize_subexpr(info, lhs.index)
        di.update(edi)
        lhs = lhs.func(lhs.indexed_object, index)
    elif isinstance(lhs, TmpScalar):
        info.scalars[lhs.varname] = lhs
    else:
        di.write(lhs)
    new_expr = func(lhs, rhs)
    return new_expr, di
[docs]
@classmethod
def write_discrete(cls, info, expr, dfield, di):
    """Record one write access of `dfield` at the component index of `expr`."""
    di.write(dfield, expr.index, 1)
[docs]
@classmethod
def read_discrete(cls, info, expr, dfield, di):
    """Record one read access of `dfield` at the component index of `expr`."""
    di.read(dfield, expr.index, 1)
[docs]
@classmethod
def discretize_subexpr(cls, info, expr):
    """Recursively discretize a symbolic subexpression.

    Dispatches on the type of *expr* and returns a pair
    ``(discretized_expr, di)`` where ``di`` is an ExprDiscretizationInfo
    accumulating the read/write accesses implied by the expression.

    Fix over the previous revision: the bare ``except:`` around the sympy
    node rebuild is narrowed to ``except Exception:`` so that
    KeyboardInterrupt/SystemExit are not intercepted for error reporting.
    """
    di = ExprDiscretizationInfo()
    if isinstance(expr, (list, tuple, set, npw.ndarray)):
        # Container: discretize each element and merge access information.
        texpr = type(expr)
        E = ()
        for e in expr:
            e, edi = cls.discretize_subexpr(info, e)
            di.update(edi)
            E += (e,)
        if texpr in (list, tuple, set):
            expr = texpr(E)
        else:
            # numpy arrays (and subclasses) are rebuilt as plain lists.
            expr = list(E)
        return expr, di
    elif isinstance(expr, (int, float, complex, npw.number)):
        # Plain numeric constants need no discretization.
        return expr, di
    elif cls.should_transpose_expr(info, expr):
        # Space/index symbols are remapped according to the transposition state.
        expr = cls.transpose_expr(info, expr)
        return expr, di
    elif isinstance(expr, Assignment):
        return cls.discretize_assignment(info, expr)
    elif isinstance(expr, Cast):
        # Discretize the casted expression, keep the target type arguments.
        e, edi = cls.discretize_subexpr(info, expr.expr)
        di.update(edi)
        return expr.func(e, *expr.args[1:]), di
    elif isinstance(expr, MutexLock):
        # Locking both reads and writes the mutex variable.
        var = expr.mutexes
        di.read(var)
        di.write(var)
        args, edi = cls.discretize_subexpr(info, expr.args[1:])
        expr = expr.func(var, *args)
        di.update(edi)
        return expr, di
    elif isinstance(expr, MutexUnlock):
        # Unlocking only writes the mutex variable.
        var = expr.mutexes
        di.write(var)
        args, edi = cls.discretize_subexpr(info, expr.args[1:])
        expr = expr.func(var, *args)
        di.update(edi)
        return expr, di
    elif isinstance(expr, TmpScalar):
        return expr, di
    elif isinstance(expr, str):
        return expr, di
    elif isinstance(expr, SymbolicScalarParameter):
        di.push_parameters(expr.parameter)
        return expr, di
    elif isinstance(expr, (SymbolicArray, SymbolicBuffer)):
        di.read(expr)
        return expr, di
    elif isinstance(expr, AppliedSymbolicField):
        # Replace a continuous field access by its discrete counterpart.
        indexed_field = expr.indexed_field
        index, field = expr.index, expr.field
        dfield = info.input_dfields[field]
        cls.read_discrete(info, expr, dfield, di)
        return dfield.s[index], di
    elif isinstance(expr, sm.Derivative):
        # Replace a spatial derivative by a centered finite-difference stencil.
        dexpr = expr.args[0]
        dvars = get_derivative_variables(expr)
        unique_dvars = set(dvars)
        invalid_dvars = unique_dvars - set(space_symbols)
        direction = info.direction
        xd = dvars[0]
        derivative = len(dvars)
        order = info.space_discretization
        assert not invalid_dvars
        assert xd in space_symbols, xd
        # Only derivatives along the current splitting direction are supported.
        assert xd == space_symbols[direction]
        assert len(unique_dvars) == 1
        # Centered stencils require a positive, even order.
        assert order > 0, order
        assert order % 2 == 0, order
        csg = CenteredStencilGenerator()
        csg.configure(dim=1, dtype=MPQ, derivative=derivative)
        stencil = csg.generate_exact_stencil(order=order)
        dexpr, di = cls.discretize_subexpr(info, dexpr)
        # Account for the extra accesses implied by the stencil footprint.
        di += stencil.non_zero_coefficients() - 1
        expr = ApplyStencil(dexpr, stencil)
        info.stencils[stencil] = di.copy()
        return expr, di
    elif isinstance(expr, (sm.Expr, sm.Rel)):
        # Generic sympy node: discretize the arguments and rebuild the node.
        new_args = ()
        for e in expr.args:
            arg, edi = cls.discretize_subexpr(info, e)
            di.update(edi)
            new_args += (arg,)
        if new_args:
            try:
                expr = expr.func(*new_args)
            except Exception:
                # Report which node failed to rebuild, then re-raise.
                msg = "Failed to build a {} from arguments {}."
                msg = msg.format(expr.func, new_args)
                print()
                print(msg)
                print()
                raise
            return expr, di
        else:
            # Leaf node (no arguments): returned unchanged.
            return expr, di
    else:
        msg = "Unknown expression type {}.\n __mro__ = {}\n"
        msg = msg.format(type(expr), type(expr).__mro__)
        raise NotImplementedError(msg)
[docs]
@classmethod
def setup_expressions(cls, info, work):
    """Run the setup pass on every discretized expression held by *info*."""
    check_instance(info, SymbolicExpressionInfo)
    for discrete_expr in info.dexprs:
        cls.setup_one(discrete_expr, work)
[docs]
@classmethod
def setup_one(cls, dexpr, work):
    """Call ``setup(work)`` on every SetupExprI atom contained in *dexpr*."""
    for setup_atom in dexpr.atoms(SetupExprI):
        setup_atom.setup(work)
[docs]
@classmethod
def transposable_expressions(cls):
    """Return the groups of symbols that are subject to axis transposition."""
    symbol_groups = (
        space_symbols,
        dspace_symbols,
        local_indices_symbols,
        global_indices_symbols,
    )
    return symbol_groups
[docs]
@classmethod
def should_transpose_expr(cls, info, expr):
    """Return True when *expr* belongs to one of the transposable symbol groups."""
    for symbol_group in cls.transposable_expressions():
        if expr in symbol_group:
            return True
    return False
[docs]
@classmethod
def transpose_expr(cls, info, expr):
    """Map the symbol *expr* to its transposed counterpart given ``info.axes``."""
    axes = info.axes
    if axes is None:
        # No transposition state available: leave the symbol untouched.
        return expr
    dim = len(axes)
    assert isinstance(axes, tuple)
    assert cls.should_transpose_expr(info, expr)
    # axes must be a permutation (no repeated axis).
    assert len(set(axes)) == dim
    symbols = next(
        (group[:dim] for group in cls.transposable_expressions() if expr in group),
        None,
    )
    assert symbols is not None
    position = symbols.index(expr)
    return symbols[axes[position]]
[docs]
class CustomSymbolicOperatorBase(DirectionalOperatorBase, metaclass=ABCMeta):
    """
    Common implementation interface for custom symbolic (code generated) operators.
    """

    # Default numerical-method configuration, merged over the base class
    # defaults by default_method() below.
    __default_method = {
        ComputeGranularity: 0,
        SpaceDiscretization: 2,
        TimeIntegrator: Euler,
        MultiScaleInterpolation: Interpolation.LINEAR,
    }
    # Accepted values for each method key, merged over the base class
    # values by available_methods() below.
    __available_methods = {
        ComputeGranularity: InstanceOf(int),
        SpaceDiscretization: InstanceOf(int),
        TimeIntegrator: InstanceOf(ExplicitRungeKutta),
        # NOTE(review): this entry is a single enum value rather than an
        # InstanceOf/container like the others — presumably only LINEAR
        # interpolation is supported; confirm against the method checker.
        MultiScaleInterpolation: Interpolation.LINEAR,
    }
[docs]
@classmethod
def default_method(cls):
    """Return the inherited defaults merged with this class' method defaults."""
    defaults = super().default_method()
    defaults.update(cls.__default_method)
    return defaults
[docs]
@classmethod
def available_methods(cls):
    """Return the inherited available methods merged with this class' entries."""
    available = super().available_methods()
    available.update(cls.__available_methods)
    return available
[docs]
@debug
def handle_method(self, method):
    """Pop, validate and store the symbolic-operator specific method entries."""
    super().handle_method(method)
    granularity = method.pop(ComputeGranularity)
    discretization_order = method.pop(SpaceDiscretization)
    integrator = method.pop(TimeIntegrator)
    multiscale_interpolation = method.pop(MultiScaleInterpolation)
    # Granularity is bounded by what the parsed expressions allow.
    assert 0 <= granularity <= self.expr_info.max_granularity, granularity
    # Centered finite differences require a positive, even order.
    assert 2 <= discretization_order, discretization_order
    assert discretization_order % 2 == 0, discretization_order
    self._expr_info.compute_granularity = granularity
    self._expr_info.time_integrator = integrator
    self._expr_info.interpolation = multiscale_interpolation
    self._expr_info.space_discretization = discretization_order
@debug
def __new__(
    cls,
    name,
    exprs,
    variables,
    splitting_direction=None,
    splitting_dim=None,
    dt_coeff=None,
    dt=None,
    time=None,
    **kwds,
):
    """Allocate a new operator instance.

    Field and parameter sets are passed as None at this stage: they are
    only known after expression parsing, which happens in __init__.
    """
    return super().__new__(
        cls,
        name=name,
        splitting_direction=splitting_direction,
        splitting_dim=splitting_dim,
        dt_coeff=dt_coeff,
        input_fields=None,
        output_fields=None,
        input_params=None,
        output_params=None,
        input_tensor_fields=None,
        output_tensor_fields=None,
        **kwds,
    )
@debug
def __init__(
    self,
    name,
    exprs,
    variables,
    splitting_direction=None,
    splitting_dim=None,
    dt_coeff=None,
    dt=None,
    time=None,
    **kwds,
):
    """
    Initialize a CustomSymbolicOperatorBase.
    Expressions are parsed and input/output vars are extracted.

    Parameters
    ----------
    exprs: array_like of valid hysop.symbolic.Expr
        Expressions that will generate code.
        Valid expressions are defined as hysop.operator.base.custom_symbolic_operator.ValidExpressions.
    variables: dict
        dictionary of fields as keys and topologies as values.
    splitting_direction: int
        Expected direction of derivatives in given expression.
    splitting_dim: int
        Only used in directional splittings.
    dt_coeff: float
        Only used in directional splittings.
    dt: ScalarParameter
        Only used for integration.
    kwds:
        Base class keyword arguments.

    Notes
    -----
    All input and output fields and parameters are directly
    extracted from expression analysis.
    """
    check_instance(variables, dict, keys=Field, values=CartesianTopologyDescriptors)
    check_instance(exprs, tuple, values=ValidExpressions, minsize=1)
    check_instance(splitting_direction, int, allow_none=True)
    check_instance(splitting_dim, int, allow_none=True)
    check_instance(dt_coeff, float, allow_none=True)
    check_instance(dt, ScalarParameter, allow_none=True)
    # splitting_dim and dt_coeff only make sense together (directional splitting).
    if (splitting_dim is None) ^ (dt_coeff is None):
        msg = "splitting_dim and dt_coeff should be specified in the same time."
        raise ValueError(msg)
    dt_coeff = first_not_None(dt_coeff, 1.0)
    # Expand tensor fields to scalar fields
    scalar_variables = {
        sfield: topod
        for (tfield, topod) in variables.items()
        for sfield in tfield.fields
    }
    expr_info = SymbolicExpressionParser.parse(
        name,
        scalar_variables,
        *exprs,
        dt=dt,
        dt_coeff=dt_coeff,
        preferred_direction=splitting_direction,
    )
    # BUGFIX: validate the user-supplied direction against the parsed one
    # *before* overwriting it. Previously splitting_direction was first set
    # to expr_info.direction, which made the mismatch check below dead code
    # (an assert did the job instead, losing the intended ValueError).
    if (splitting_direction is not None) and (
        expr_info.direction != splitting_direction
    ):
        msg = "Direction mismatch, expression has derivative in direction {} but direction {} "
        msg += "has been specified."
        msg = msg.format(expr_info.direction, splitting_direction)
        raise ValueError(msg)
    splitting_direction = expr_info.direction
    splitting_dim = first_not_None(splitting_dim, expr_info.domain.dim)
    input_fields = expr_info.input_fields
    output_fields = expr_info.output_fields
    input_params = set(expr_info.input_params.values())
    output_params = set(expr_info.output_params.values())
    # Re-expose tensor fields whose scalar components are all inputs/outputs.
    input_tensor_fields = ()
    output_tensor_fields = ()
    for tfield in filter(lambda x: x.is_tensor, variables.keys()):
        if all((f in input_fields) for f in tfield.fields):
            input_tensor_fields += (tfield,)
        if all((f in output_fields) for f in tfield.fields):
            output_tensor_fields += (tfield,)
    super().__init__(
        name=name,
        input_fields=input_fields,
        output_fields=output_fields,
        input_params=input_params,
        output_params=output_params,
        input_tensor_fields=input_tensor_fields,
        output_tensor_fields=output_tensor_fields,
        splitting_direction=splitting_direction,
        splitting_dim=splitting_dim,
        dt_coeff=dt_coeff,
        **kwds,
    )
    self._expr_info = expr_info
def _get_expr_info(self):
"""Get information about parsed symbolic expressions."""
return self._expr_info
expr_info = property(_get_expr_info)
[docs]
@debug
def get_field_requirements(self):
    """Extract field requirements from first expression parsing stage."""
    requirements = super().get_field_requirements()
    expr_info = self.expr_info
    expr_info.extract_obj_requirements()
    dim = expr_info.domain.dim
    field_reqs = expr_info.field_requirements
    array_reqs = expr_info.array_requirements
    direction = expr_info.direction
    has_direction = expr_info.has_direction
    if has_direction:
        assert 0 <= direction < dim
        # Keep only transposition states whose contiguous (last) axis matches
        # the current splitting direction.
        axes = TranspositionState[dim].filter_axes(
            lambda axes: (axes[-1] == dim - 1 - direction)
        )
        axes = tuple(axes)
    min_ghosts_per_components = SortedDict()
    # Walk input requirements then output requirements with the same logic.
    for fields, is_input, iter_requirements in zip(
        (self.input_fields, self.output_fields),
        (True, False),
        (
            requirements.iter_input_requirements,
            requirements.iter_output_requirements,
        ),
    ):
        if not fields:
            continue
        for field, td, req in iter_requirements():
            # One ghost row per (component, spatial dimension).
            min_ghosts = npw.int_zeros(shape=(field.nb_components, field.dim))
            if has_direction:
                req.axes = axes
            for index in range(field.nb_components):
                # Per-component ghost requirement parsed from expressions,
                # keyed by "<field name>_<component index>".
                fname = f"{field.name}_{index}"
                G = expr_info.min_ghosts_per_field_name.get(fname, 0)
                if (field, index) in field_reqs:
                    fi_req = field_reqs[(field, index)]
                    if fi_req.axes:
                        if req.axes is not None:
                            # The two axes constraints must be compatible.
                            assert set(fi_req.axes).intersection(req.axes)
                            req.axes = tuple(
                                set(req.axes).intersection(fi_req.axes)
                            )
                        else:
                            req.axes = fi_req.axes
                    _min_ghosts = fi_req.min_ghosts.copy()
                    _max_ghosts = fi_req.max_ghosts.copy()
                    # G must fit inside the [min, max] ghost bounds along the
                    # splitting direction (last axis is dim-1-direction).
                    assert _min_ghosts[dim - 1 - direction] <= G
                    assert _max_ghosts[dim - 1 - direction] >= G
                    _min_ghosts[dim - 1 - direction] = G
                    req.min_ghosts = npw.maximum(_min_ghosts, req.min_ghosts)
                    req.max_ghosts = npw.minimum(_max_ghosts, req.max_ghosts)
                    min_ghosts[index] = _min_ghosts.copy()
                else:
                    # No explicit requirement: just enforce G along the
                    # splitting direction.
                    req.min_ghosts[dim - 1 - direction] = max(
                        G, req.min_ghosts[dim - 1 - direction]
                    )
                    min_ghosts[index][dim - 1 - direction] = G
                assert req.min_ghosts[dim - 1 - direction] >= G
                assert req.max_ghosts[dim - 1 - direction] >= G
            # First occurrence wins: a field seen both as input and output
            # keeps the ghosts computed on its first pass.
            if field not in min_ghosts_per_components:
                min_ghosts_per_components[field] = min_ghosts
    # Per-field ghosts: component-wise maximum over all components.
    expr_info.min_ghosts = {
        k: npw.max(v, axis=0) for (k, v) in min_ghosts_per_components.items()
    }
    expr_info.min_ghosts_per_components = {
        field: gpc[:, -1 - direction]
        for (field, gpc) in min_ghosts_per_components.items()
    }
    # NOTE(review): array_reqs is unpacked as (array, reqs) pairs here —
    # presumably an iterable of 2-tuples (a dict would need .items()); confirm
    # against extract_obj_requirements().
    for array, reqs in array_reqs:
        expr_info.min_ghosts[array] = reqs.min_ghosts.copy()
        # NOTE(review): this *rebinds* min_ghosts_per_components (the dict
        # built just above) to a scalar/slice, clobbering the per-field data
        # when any array requirement exists. It looks like it was meant to be
        # `expr_info.min_ghosts_per_components[array] = ...` — confirm before
        # changing, as downstream code may rely on the current behavior.
        expr_info.min_ghosts_per_components = reqs.min_ghosts[-1 - direction]
    return requirements
[docs]
@debug
def discretize(self, force_symbolic_axes=None):
    """Discretize variables and symbolic expressions (idempotent)."""
    if self.discretized:
        # Discretization must only happen once.
        return
    super().discretize()
    expr_info = self._expr_info
    expr_info.discretize_expressions(
        output_dfields=self.output_discrete_fields,
        input_dfields=self.input_discrete_fields,
        force_symbolic_axes=force_symbolic_axes,
    )
[docs]
@debug
def setup(self, work):
"""Setup required work."""
self._expr_info.check_arrays()
self._expr_info.check_buffers()
super().setup(work)
if work is None:
raise ValueError("work is None.")